package com.iflytek.jzcpx.procuracy.tools.ocr;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URLEncoder;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import cn.hutool.core.thread.ThreadFactoryBuilder;
import cn.hutool.http.HttpRequest;
import com.alibaba.fastjson.JSONObject;
import com.iflytek.jzcpx.procuracy.common.result.Result;
import com.iflytek.jzcpx.procuracy.common.util.IdWokerUtil;
import com.iflytek.jzcpx.procuracy.common.util.JSONUtil;
import com.iflytek.jzcpx.procuracy.common.util.ThreadPoolUtils;
import com.iflytek.jzcpx.procuracy.tools.common.enums.ToolsContstants;
import com.yomahub.tlog.core.thread.TLogInheritableTask;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;

/**
 * OCR client.
 *
 * <p>Every request to the OCR engine is executed on a private, fixed-size thread pool whose
 * work queue is ordered by {@link RequestPriority}, so higher-priority requests are dequeued
 * first. All public {@code request*} methods block the calling thread until the pooled task
 * has finished.
 *
 * @author <a href=mailto:ktyi@iflytek.com>伊开堂</a>
 * @date 2019-08-06 18:50
 */
public class OcrClient {
    private static final Logger logger = LoggerFactory.getLogger(OcrClient.class);

    /** Pool size used when the constructor receives no positive {@code coreSize}. */
    private static final int DEFAULT_CONCURRENCY = 32;

    /** Attempt count used when the constructor receives no positive {@code retry}. */
    private static final int DEFAULT_RETRY = 3;

    /** Initial capacity of the priority work queue (the queue itself is unbounded). */
    private static final int QUEUE_INITIAL_CAPACITY = 10000;

    /**
     * Default engine parameters in map form. Shared by all requests: it must never be mutated,
     * callers merge overrides into a copy.
     */
    private static final Map<String, String> DEFAULT_PARAMS = new HashMap<>();

    /** Default engine parameters in JSON form, used by the multipart (RestTemplate) path. */
    private static final JSONObject DEFAULT_PARAMS_JSON = new JSONObject();

    static {
        DEFAULT_PARAMS.put("pagetype", "page");
        DEFAULT_PARAMS.put("funclist", "");
        DEFAULT_PARAMS.put("resultlevel", "3");

        DEFAULT_PARAMS_JSON.put("pagetype", "page");
        // funclist values: od = orientation detection; ic = image classification;
        // er = evidence detection (signatures/seals, stamps, illustrations, handwriting).
        // An alternative setting is "ocr,od"; the default leaves the list empty.
        DEFAULT_PARAMS_JSON.put("funclist", "");
        DEFAULT_PARAMS_JSON.put("resultlevel", "3");
    }

    /** OCR endpoint URL; may contain a <code>{trackId}</code> placeholder replaced per request. */
    private final String SERVER_URL;

    /** Number of worker threads in the request pool. */
    private final int concurrency;

    /** Maximum number of attempts made by {@link #requestRestTemplate(File, JSONObject)}. */
    private final int retry;

    /** Thread pool that executes the OCR engine requests. */
    private final ThreadPoolExecutor poolExecutor;

    @Autowired
    private RestTemplate restTemplate;

    /**
     * Creates a client and its request thread pool.
     *
     * @param server   base address of the OCR service; a trailing "/" is removed
     * @param path     endpoint path; a leading "/" is added when missing
     * @param coreSize pool size; {@code null} or non-positive falls back to 32
     * @param retry    maximum attempts for the retrying request; {@code null} or non-positive
     *                 falls back to 3
     */
    public OcrClient(String server, String path, Integer coreSize, Integer retry) {
        this.SERVER_URL = StringUtils.removeEnd(server, "/") + StringUtils.prependIfMissing(path, "/");
        this.concurrency = (coreSize != null && coreSize > 0) ? coreSize : DEFAULT_CONCURRENCY;
        this.retry = (retry != null && retry > 0) ? retry : DEFAULT_RETRY;
        this.poolExecutor = buildExecutor(this.concurrency);
        logger.info("创建 Ocr 客户端, 接口地址: {}, 最大并发: {}", SERVER_URL, concurrency);
    }

    /**
     * Builds the fixed-size pool backed by a priority queue.
     *
     * <p>The pool only supports tasks submitted as {@link RequestTask} through
     * {@code submit(Callable)}: {@code newTaskFor} and the queue comparator both cast to the
     * priority-aware types.
     *
     * @param corePoolSize number of worker threads (core size equals maximum size)
     * @return the configured executor
     */
    private static ThreadPoolExecutor buildExecutor(final int corePoolSize) {
        Comparator<Runnable> byPriorityDesc = OcrClient::comparePriority;
        PriorityBlockingQueue<Runnable> workQueue =
                new PriorityBlockingQueue<>(QUEUE_INITIAL_CAPACITY, byPriorityDesc);

        return new ThreadPoolExecutor(corePoolSize, corePoolSize, 0, TimeUnit.MINUTES, workQueue,
                                      ThreadFactoryBuilder.create().setNamePrefix("OcrEngine-Requester-").build(),
                                      new ThreadPoolExecutor.CallerRunsPolicy()) {

            @Override
            protected <T> RunnableFuture<T> newTaskFor(Callable<T> callable) {
                // Wrap the standard future so the queue can read the task's priority.
                RunnableFuture<T> delegate = super.newTaskFor(callable);
                return new PriorityFuture<T>(delegate, ((RequestTask<T>) callable).getPriority());
            }
        };
    }

    /**
     * Orders queued tasks so that a higher priority level sorts first; {@code null} sorts
     * before any non-null task.
     */
    private static int comparePriority(Runnable o1, Runnable o2) {
        if (o1 == null || o2 == null) {
            if (o1 == o2) {
                return 0;
            }
            return o1 == null ? -1 : 1;
        }
        int level1 = ((PriorityFuture<?>) o1).getPriority().getLevel();
        int level2 = ((PriorityFuture<?>) o2).getPriority().getLevel();
        // Reversed arguments give descending order; Integer.compare cannot overflow.
        return Integer.compare(level2, level1);
    }

    /** A unit of request work that carries the priority it should be scheduled with. */
    private abstract static class RequestTask<V> implements Callable<V> {
        private final RequestPriority priority;

        RequestTask(RequestPriority priority) {
            this.priority = priority;
        }

        public RequestPriority getPriority() {
            return priority;
        }
    }

    /** Element stored in the pool's queue: a future delegating to {@code src} plus its priority. */
    private static class PriorityFuture<T> extends TLogInheritableTask implements RunnableFuture<T> {
        private final RunnableFuture<T> src;
        private final RequestPriority priority;

        public PriorityFuture(RunnableFuture<T> other, RequestPriority priority) {
            this.src = other;
            this.priority = priority;
        }

        public RequestPriority getPriority() {
            return priority;
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return src.cancel(mayInterruptIfRunning);
        }

        @Override
        public boolean isCancelled() {
            return src.isCancelled();
        }

        @Override
        public boolean isDone() {
            return src.isDone();
        }

        @Override
        public T get() throws InterruptedException, ExecutionException {
            return src.get();
        }

        @Override
        public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            // Honour the caller's timeout instead of blocking indefinitely.
            return src.get(timeout, unit);
        }

        @Override
        public void runTask() {
            logger.info("run of {}", priority);
            src.run();
        }
    }

    /** Scheduling priority of a request; a larger level is served earlier. */
    public enum RequestPriority {
        /** Low priority. */
        LOW(1),
        /** Normal priority. */
        NORMAL(3),
        /** High priority. */
        HIGH(5),
        ;

        private final int level;

        RequestPriority(int level) {
            this.level = level;
        }

        public int getLevel() {
            return level;
        }
    }

    /**
     * Extracts the {@code body} property of the engine's JSON response and wraps it in a
     * successful {@link Result}.
     */
    private static Result<String> toBodyResult(String resp) throws IOException {
        return Result.success(JSONUtil.getProp(JSONUtil.toJsonNode(resp), "body", String.class));
    }

    /**
     * Recognizes a file with the default parameters and a generated track id.
     *
     * @param file image file to recognize
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(File file) throws Exception {
        return request(IdWokerUtil.nextId() + "", file, DEFAULT_PARAMS);
    }

    /**
     * Recognizes a file with the given parameters and a generated track id.
     *
     * @param file      image file to recognize
     * @param ocrParams engine parameters; may be {@code null}
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(File file, Map<String, String> ocrParams) throws Exception {
        return request(IdWokerUtil.nextId() + "", file, ocrParams);
    }

    /**
     * Recognizes a file with the default parameters under the given track id.
     *
     * @param trackId value substituted for <code>{trackId}</code> in the endpoint URL
     * @param file    image file to recognize
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(String trackId, File file) throws Exception {
        return request(trackId, file, DEFAULT_PARAMS);
    }

    /**
     * Posts a JSON document containing {@code picFile} and, when given, {@code params} to the
     * engine at normal priority.
     *
     * @param trackId value substituted for <code>{trackId}</code> in the endpoint URL
     * @param file    image file to recognize
     * @param params  engine parameters, serialized to a JSON string; skipped when {@code null}
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(String trackId, File file, Map<String, String> params) throws Exception {
        String serviceUrl = StringUtils.replace(this.SERVER_URL, "{trackId}", trackId);

        // NOTE(review): the File object itself is placed in the JSON payload, so what the engine
        // receives depends on how JSONUtil serializes java.io.File - confirm this is intended.
        Map<String, Object> payload = new HashMap<>();
        payload.put("picFile", file);
        if (params != null) {
            payload.put("params", JSONUtil.toStrDefault(params));
        }

        String resp = poolExecutor.submit(new RequestTask<String>(RequestPriority.NORMAL) {
            @Override
            public String call() throws Exception {
                return HttpRequest.post(serviceUrl).body(JSONUtil.toStrDefault(payload))
                                  .header("Content-Type", "application/json; charset=UTF-8")
                                  .timeout(ToolsContstants.ENGINE_REQUEST_TIMEEOUT).execute().body();
            }
        }).get();

        return toBodyResult(resp);
    }

    /**
     * Recognizes a file through the multipart upload path using the default JSON parameters.
     *
     * @param file image file to upload
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> requestRestTemplate(File file) throws Exception {
        return requestRestTemplate(file, null);
    }

    /**
     * Recognizes a file through the multipart upload path, retrying on empty or failed responses.
     *
     * @param file   image file to upload
     * @param params engine parameters; {@code null} selects the default JSON parameters
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> requestRestTemplate(File file, JSONObject params) throws Exception {
        return requestWithRetry(file, params);
    }

    /**
     * Uploads {@code file} as multipart form field {@code picFile}, passing the parameters as
     * the {@code params} query variable. The request is attempted up to {@code retry} times
     * while the response is empty; exceptions thrown by an attempt are logged and swallowed.
     */
    private Result<String> requestWithRetry(File file, JSONObject params)
            throws InterruptedException, ExecutionException, IOException {
        String trackId = UUID.randomUUID().toString();
        String serviceUrl = StringUtils.replace(this.SERVER_URL, "{trackId}", trackId) + "?params={params}";

        JSONObject effectiveParams = (params == null) ? DEFAULT_PARAMS_JSON : params;
        String paramJson = effectiveParams.toJSONString();

        HttpHeaders headers = new HttpHeaders();
        headers.add("Content-Type", MediaType.MULTIPART_FORM_DATA_VALUE);
        MultiValueMap<String, Object> form = new LinkedMultiValueMap<>();
        form.add("picFile", new FileSystemResource(file));
        HttpEntity<MultiValueMap<String, Object>> httpEntity = new HttpEntity<>(form, headers);

        logger.info("提交OCR引擎识别请求, url: {}, poolExecutor: {}", serviceUrl, ThreadPoolUtils.poolInfo(poolExecutor));
        String resp = poolExecutor.submit(new RequestTask<String>(RequestPriority.NORMAL) {
            @Override
            public String call() throws Exception {
                long start = System.currentTimeMillis();
                logger.info("开始请求OCR引擎, serviceUrl: {}, poolExecutor: {}", serviceUrl, ThreadPoolUtils.poolInfo(poolExecutor));
                int requestCount = 1;
                String respStr = null;
                do {
                    long attemptStart = System.currentTimeMillis();
                    try {
                        respStr = restTemplate.postForObject(serviceUrl, httpEntity, String.class, paramJson);
                    }
                    catch (Exception e) {
                        // Best effort: a failed attempt is logged and retried by the loop condition.
                        logger.warn("请求OCR引擎异常", e);
                    } finally {
                        logger.info("第{}次请求OCR引擎耗时: {}ms", requestCount, System.currentTimeMillis() - attemptStart);
                    }
                } while (StringUtils.isEmpty(respStr) && requestCount++ < retry);

                logger.info("请求OCR引擎结束, 总耗时: {}ms, 结果字节数: {}", System.currentTimeMillis() - start,
                            StringUtils.length(respStr));
                return respStr;
            }
        }).get();

        logger.info("请求OCR引擎完成, poolExecutor: {}", ThreadPoolUtils.poolInfo(poolExecutor));
        // NOTE(review): resp is null/empty when every attempt failed; the outcome then depends
        // on how JSONUtil handles such input - confirm.
        return toBodyResult(resp);
    }

    /**
     * Recognizes a base64-encoded image at normal priority with a generated track id.
     *
     * @param base64Img base64 text of the image
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(String base64Img) throws Exception {
        return request(String.valueOf(IdWokerUtil.nextId()), base64Img);
    }

    /**
     * Recognizes a base64-encoded image at the given priority with a generated track id.
     *
     * @param base64Img base64 text of the image
     * @param priority  scheduling priority; {@code null} means {@link RequestPriority#NORMAL}
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(String base64Img, RequestPriority priority) throws Exception {
        return request(String.valueOf(IdWokerUtil.nextId()), base64Img, null, priority);
    }

    /**
     * Posts a base64-encoded image as a plain-text body. The endpoint is derived from the
     * configured URL by replacing "file" with "base64" (e.g. /tuling/ocr/v2/base64/{trackId}),
     * and the merged parameters are appended as a query string.
     *
     * @param trackId   value substituted for <code>{trackId}</code> in the endpoint URL
     * @param base64Img base64 text of the image
     * @param params    overrides merged on top of the default parameters; may be {@code null}
     * @param priority  scheduling priority; {@code null} means {@link RequestPriority#NORMAL}
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(String trackId, String base64Img, Map<String, String> params,
            RequestPriority priority) throws Exception {
        String serviceUrl = StringUtils
                .replace(StringUtils.replace(this.SERVER_URL, "{trackId}", trackId), "file", "base64");

        // Merge into a copy: DEFAULT_PARAMS is shared by every request and must stay untouched.
        Map<String, String> mergedParams = new HashMap<>(DEFAULT_PARAMS);
        if (params != null) {
            mergedParams.putAll(params);
        }

        StringBuilder query = new StringBuilder();
        for (Map.Entry<String, String> entry : mergedParams.entrySet()) {
            if (query.length() > 0) {
                query.append("&");
            }
            // Explicit UTF-8 so the encoding does not depend on the platform default charset.
            query.append(entry.getKey()).append("=").append(URLEncoder.encode(entry.getValue(), "UTF-8"));
        }
        String finalServiceUrl = serviceUrl + "?" + query;

        // Default to normal priority.
        RequestPriority effectivePriority = (priority == null) ? RequestPriority.NORMAL : priority;

        String resp = poolExecutor.submit(new RequestTask<String>(effectivePriority) {
            @Override
            public String call() throws Exception {
                return HttpRequest.post(finalServiceUrl).body(base64Img)
                                  .header("Content-Type", "text/plain; charset=UTF-8")
                                  .timeout(ToolsContstants.ENGINE_REQUEST_TIMEEOUT).execute().body();
            }
        }).get();

        return toBodyResult(resp);
    }

    /**
     * Recognizes a base64-encoded image at normal priority with the default parameters.
     *
     * @param trackId   value substituted for <code>{trackId}</code> in the endpoint URL
     * @param base64Img base64 text of the image
     * @return the {@code body} of the engine response
     * @throws Exception if the request or response parsing fails
     */
    public Result<String> request(String trackId, String base64Img) throws Exception {
        return request(trackId, base64Img, null, null);
    }

    /**
     * Reads the whole stream into memory and uploads it as multipart form field {@code picFile}
     * at normal priority. The stream is not closed by this method.
     *
     * @param inputStream image content
     * @param fileName    file name reported for the uploaded part
     * @param params      overrides merged on top of the default parameters; may be {@code null}
     * @return the {@code body} of the engine response on HTTP 200, otherwise a failed result
     * @throws Exception if reading the stream, the request or response parsing fails
     */
    public Result<String> request(InputStream inputStream, String fileName, Map<String, String> params)
            throws Exception {
        String serviceUrl = StringUtils.replace(this.SERVER_URL, "{trackId}", String.valueOf(IdWokerUtil.nextId()))
                + "?params={params}";

        // Merge into a copy: DEFAULT_PARAMS is shared by every request and must stay untouched.
        Map<String, String> mergedParams = new HashMap<>(DEFAULT_PARAMS);
        if (params != null) {
            mergedParams.putAll(params);
        }

        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.MULTIPART_FORM_DATA);

        // A ByteArrayResource has no file name by default; supply one for the uploaded part.
        ByteArrayResource content = new ByteArrayResource(IOUtils.toByteArray(inputStream)) {
            @Override
            public String getFilename() {
                return fileName;
            }
        };
        MultiValueMap<String, Object> form = new LinkedMultiValueMap<>();
        form.add("picFile", content);
        HttpEntity<MultiValueMap<String, Object>> httpEntity = new HttpEntity<>(form, headers);

        ResponseEntity<String> result = poolExecutor
                .submit(new RequestTask<ResponseEntity<String>>(RequestPriority.NORMAL) {
                    @Override
                    public ResponseEntity<String> call() throws Exception {
                        String paramJson = JSONUtil.toStrDefault(mergedParams);
                        logger.debug("request, filename: {}, url: {}, params: {}", fileName, serviceUrl, paramJson);
                        return restTemplate.exchange(serviceUrl, HttpMethod.POST, httpEntity, String.class, paramJson);
                    }
                }).get();

        if (result == null) {
            return Result.failed("调用ocr引擎失败");
        }
        if (result.getStatusCode() == HttpStatus.OK) {
            return toBodyResult(result.getBody());
        }

        logger.debug("调用 ocr 引擎失败, url: {}, responseCode: {}, responseBody: {}", serviceUrl,
                result.getStatusCodeValue(), result.getBody());
        return Result.failed("调用ocr引擎失败");
    }
}
